{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W_tvPdyfA-BL"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0O_LFhwSBCjm"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PWUmcKKjtwXL"
      },
      "source": [
        "# Transfer learning with TensorFlow Hub\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning_with_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning_with_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/transfer_learning_with_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "crU-iluJIEzw"
      },
      "source": [
        "[TensorFlow Hub](https://tfhub.dev/) is a repository of pre-trained TensorFlow models.\n",
        "\n",
        "This tutorial demonstrates how to:\n",
        "\n",
        "1. Use models from TensorFlow Hub with `tf.keras`\n",
        "1. Use an image classification model from TensorFlow Hub\n",
        "1. Do simple transfer learning to fine-tune a model for your own image classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CKFUvuEho9Th"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OGNpmn43C0O6"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import time\n",
        "\n",
        "import PIL.Image as Image\n",
        "import matplotlib.pylab as plt\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "import datetime\n",
        "\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s4YuF5HvpM1W"
      },
      "source": [
        "## An ImageNet classifier\n",
        "\n",
        "You'll start by using a classifier model pre-trained on the [ImageNet](https://en.wikipedia.org/wiki/ImageNet) benchmark dataset—no initial training required!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xEY_Ow5loN6q"
      },
      "source": [
        "### Download the classifier\n",
        "\n",
        "Select a <a href=\"https://arxiv.org/abs/1801.04381\" class=\"external\">MobileNetV2</a> pre-trained model [from TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2) and wrap it as a Keras layer with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer). Any <a href=\"https://tfhub.dev/s?q=tf2&module-type=image-classification/\" class=\"external\">compatible image classifier model</a> from TensorFlow Hub will work here, including the examples provided in the drop-down below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "feiXojVXAbI9"
      },
      "outputs": [],
      "source": [
        "mobilenet_v2 =\"https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4\"\n",
        "inception_v3 = \"https://tfhub.dev/google/imagenet/inception_v3/classification/5\"\n",
        "\n",
        "classifier_model = mobilenet_v2 #@param [\"mobilenet_v2\", \"inception_v3\"] {type:\"raw\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y_6bGjoPtzau"
      },
      "outputs": [],
      "source": [
        "IMAGE_SHAPE = (224, 224)\n",
        "\n",
        "classifier = tf.keras.Sequential([\n",
        "    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pwZXaoV0uXp2"
      },
      "source": [
        "### Run it on a single image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TQItP1i55-di"
      },
      "source": [
        "Download a single image to try the model on:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w5wDjXNjuXGD"
      },
      "outputs": [],
      "source": [
        "grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')\n",
        "grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)\n",
        "grace_hopper"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BEmmBnGbLxPp"
      },
      "outputs": [],
      "source": [
        "grace_hopper = np.array(grace_hopper)/255.0\n",
        "grace_hopper.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Ic8OEEo2b73"
      },
      "source": [
        "Add a batch dimension (with `np.newaxis`) and pass the image to the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EMquyn29v8q3"
      },
      "outputs": [],
      "source": [
        "result = classifier.predict(grace_hopper[np.newaxis, ...])\n",
        "result.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NKzjqENF6jDF"
      },
      "source": [
        "The result is a 1001-element vector of logits, rating the probability of each class for the image.\n",
        "\n",
        "The top class ID can be found with `tf.math.argmax`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rgXb44vt6goJ"
      },
      "outputs": [],
      "source": [
        "predicted_class = tf.math.argmax(result[0], axis=-1)\n",
        "predicted_class"
      ]
    },
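    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Logits are unnormalized scores, so they don't sum to one. If you want values you can read as probabilities, one common option is to apply a softmax. The following is a minimal NumPy sketch, not part of the classifier itself: `softmax`, `example_logits`, and `example_probs` are hypothetical names, and the four logits are made-up values rather than model output. Softmax preserves the ordering of the scores, so the top class is the same either way. On the real output, `tf.nn.softmax(result[0])` should play the same role."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax(logits):\n",
        "  # Subtract the maximum before exponentiating for numerical stability.\n",
        "  shifted = logits - np.max(logits)\n",
        "  exp = np.exp(shifted)\n",
        "  return exp / exp.sum()\n",
        "\n",
        "# Made-up logits for a four-class problem (an assumption, not classifier output).\n",
        "example_logits = np.array([2.0, 1.0, 0.1, -1.0])\n",
        "example_probs = softmax(example_logits)\n",
        "\n",
        "print(example_probs)\n",
        "print(example_probs.sum())"
      ]
    },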
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YrxLMajMoxkf"
      },
      "source": [
        "### Decode the predictions\n",
        "\n",
        "Take the `predicted_class` ID (such as `653`) and fetch the ImageNet dataset labels to decode the predictions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ij6SrDxcxzry"
      },
      "outputs": [],
      "source": [
        "labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')\n",
        "imagenet_labels = np.array(open(labels_path).read().splitlines())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzziRK3Z2VQo"
      },
      "outputs": [],
      "source": [
        "plt.imshow(grace_hopper)\n",
        "plt.axis('off')\n",
        "predicted_class_name = imagenet_labels[predicted_class]\n",
        "_ = plt.title(\"Prediction: \" + predicted_class_name.title())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amfzqn1Oo7Om"
      },
      "source": [
        "## Simple transfer learning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K-nIpVJ94xrw"
      },
      "source": [
        "But what if you want to create a custom classifier using your own dataset that has classes that aren't included in the original ImageNet dataset (that the pre-trained model was trained on)?\n",
        "\n",
        "To do that, you can:\n",
        "\n",
        "1. Select a pre-trained model from TensorFlow Hub; and\n",
        "2. Retrain the top (last) layer to recognize the classes from your custom dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z93vvAdGxDMD"
      },
      "source": [
        "### Dataset\n",
        "\n",
        "In this example, you will use the TensorFlow flowers dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DrIUV3V0xDL_"
      },
      "outputs": [],
      "source": [
        "data_root = tf.keras.utils.get_file(\n",
        "  'flower_photos',\n",
        "  'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
        "   untar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jFHdp18ccah7"
      },
      "source": [
        "First, load this data into the model using the image data off disk with `tf.keras.preprocessing.image_dataset_from_directory`, which will generate a `tf.data.Dataset`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mqnsczfLgcwv"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "img_height = 224\n",
        "img_width = 224\n",
        "\n",
        "train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
        "  str(data_root),\n",
        "  validation_split=0.2,\n",
        "  subset=\"training\",\n",
        "  seed=123,\n",
        "  image_size=(img_height, img_width),\n",
        "  batch_size=batch_size\n",
        ")\n",
        "\n",
        "val_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
        "  str(data_root),\n",
        "  validation_split=0.2,\n",
        "  subset=\"validation\",\n",
        "  seed=123,\n",
        "  image_size=(img_height, img_width),\n",
        "  batch_size=batch_size\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cCrSRlomEIZ4"
      },
      "source": [
        "The flowers dataset has five classes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AFgDHs6VEFRD"
      },
      "outputs": [],
      "source": [
        "class_names = np.array(train_ds.class_names)\n",
        "print(class_names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L0Btd0V3C8h4"
      },
      "source": [
        "Second, because TensorFlow Hub's convention for image models is to expect float inputs in the `[0, 1]` range, use the `tf.keras.layers.experimental.preprocessing.Rescaling` layer to achieve this."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rs6gfO-ApTQW"
      },
      "source": [
        "Note: You could also include the `tf.keras.layers.experimental.preprocessing.Rescaling` layer inside the model. Refer to the [Working with preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) guide for a discussion of the tradeoffs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8NzDDWEMCL20"
      },
      "outputs": [],
      "source": [
        "normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)\n",
        "train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.\n",
        "val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IW-BUJ-NC7y-"
      },
      "source": [
        "Third, finish the input pipeline by using buffered prefetching with `Dataset.prefetch`, so you can yield the data from disk without I/O blocking issues.\n",
        "\n",
        "These are some of the most important `tf.data` methods you should use when loading data. Interested readers can learn more about them, as well as how to cache data to disk and other techniques, in the [Better performance with the tf.data API](https://www.tensorflow.org/guide/data_performance#prefetching) guide."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZmJMKFw7C4ki"
      },
      "outputs": [],
      "source": [
        "AUTOTUNE = tf.data.AUTOTUNE\n",
        "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m0JyiEZ0imgf"
      },
      "outputs": [],
      "source": [
        "for image_batch, labels_batch in train_ds:\n",
        "  print(image_batch.shape)\n",
        "  print(labels_batch.shape)\n",
        "  break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0gTN7M_GxDLx"
      },
      "source": [
        "### Run the classifier on a batch of images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O3fvrZR8xDLv"
      },
      "source": [
        "Now, run the classifier on an image batch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pcFeNcrehEue"
      },
      "outputs": [],
      "source": [
        "result_batch = classifier.predict(train_ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-wK2ky45hlyS"
      },
      "outputs": [],
      "source": [
        "predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]\n",
        "predicted_class_names"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QmvSWg9nxDLa"
      },
      "source": [
        "Check how these predictions line up with the images:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IXTB22SpxDLP"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,9))\n",
        "plt.subplots_adjust(hspace=0.5)\n",
        "for n in range(30):\n",
        "  plt.subplot(6,5,n+1)\n",
        "  plt.imshow(image_batch[n])\n",
        "  plt.title(predicted_class_names[n])\n",
        "  plt.axis('off')\n",
        "_ = plt.suptitle(\"ImageNet predictions\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FUa3YkvhxDLM"
      },
      "source": [
        "Note: all images are licensed CC-BY, creators are listed in the LICENSE.txt file.\n",
        "\n",
        "The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except for \"daisy\")."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JzV457OXreQP"
      },
      "source": [
        "### Download the headless model\n",
        "\n",
        "TensorFlow Hub also distributes models without the top classification layer. These can be used to easily perform transfer learning.\n",
        "\n",
        "Select a <a href=\"https://arxiv.org/abs/1801.04381\" class=\"external\">MobileNetV2</a> pre-trained model <a href=\"https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4\" class=\"external\">from TensorFlow Hub</a>. Any <a href=\"https://tfhub.dev/s?module-type=image-feature-vector&q=tf2\" class=\"external\">compatible image feature vector model</a> from TensorFlow Hub will work here, including the examples from the drop-down menu."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4bw8Jf94DSnP"
      },
      "outputs": [],
      "source": [
        "mobilenet_v2 = \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4\"\n",
        "inception_v3 = \"https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4\"\n",
        "\n",
        "feature_extractor_model = mobilenet_v2 #@param [\"mobilenet_v2\", \"inception_v3\"] {type:\"raw\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sgwmHugQF-PD"
      },
      "source": [
        "Create the feature extractor by wrapping the pre-trained model as a Keras layer with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer). Use the `trainable=False` argument to freeze the variables, so that the training only modifies the new classifier layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5wB030nezBwI"
      },
      "outputs": [],
      "source": [
        "feature_extractor_layer = hub.KerasLayer(\n",
        "    feature_extractor_model,\n",
        "    input_shape=(224, 224, 3),\n",
        "    trainable=False)"
      ]
    },
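    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what freezing means in practice, here is a small, self-contained sketch that doesn't download anything. It assumes a plain `Dense` layer as a stand-in for the pre-trained feature extractor; `frozen_base`, `new_head`, and `sketch_model` are hypothetical names used only for illustration. With `trainable=False` on the base, only the head's kernel and bias should show up among the trainable weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical stand-in for a pre-trained feature extractor: a frozen Dense layer.\n",
        "frozen_base = tf.keras.layers.Dense(4, trainable=False)\n",
        "# New classification head that stays trainable.\n",
        "new_head = tf.keras.layers.Dense(2)\n",
        "\n",
        "sketch_model = tf.keras.Sequential([\n",
        "    tf.keras.Input(shape=(8,)),\n",
        "    frozen_base,\n",
        "    new_head,\n",
        "])\n",
        "\n",
        "# Count the weight tensors that training would and would not update.\n",
        "num_trainable = len(sketch_model.trainable_weights)\n",
        "num_frozen = len(sketch_model.non_trainable_weights)\n",
        "print(num_trainable, num_frozen)"
      ]
    },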
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0QzVdu4ZhcDE"
      },
      "source": [
        "The feature extractor returns a 1280-long vector for each image (the image batch size remains at 32 in this example):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Of7i-35F09ls"
      },
      "outputs": [],
      "source": [
        "feature_batch = feature_extractor_layer(image_batch)\n",
        "print(feature_batch.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RPVeouTksO9q"
      },
      "source": [
        "### Attach a classification head\n",
        "\n",
        "To complete the model, wrap the feature extractor layer in a `tf.keras.Sequential` model and add a fully-connected layer for classification:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vQq_kCWzlqSu"
      },
      "outputs": [],
      "source": [
        "num_classes = len(class_names)\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "  feature_extractor_layer,\n",
        "  tf.keras.layers.Dense(num_classes)\n",
        "])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IyhX4VCFmzVS"
      },
      "outputs": [],
      "source": [
        "predictions = model(image_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FQdUaTkzm3jQ"
      },
      "outputs": [],
      "source": [
        "predictions.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHbXQqIquFxQ"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "Use `Model.compile` to configure the training process and add a `tf.keras.callbacks.TensorBoard` callback to create and store logs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4xRx8Rjzm67O"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "  optimizer=tf.keras.optimizers.Adam(),\n",
        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "  metrics=['acc'])\n",
        "\n",
        "log_dir = \"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "tensorboard_callback = tf.keras.callbacks.TensorBoard(\n",
        "    log_dir=log_dir,\n",
        "    histogram_freq=1) # Enable histogram computation for every epoch."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "58-BLV7dupJA"
      },
      "source": [
        "Now use the `Model.fit` method to train the model.\n",
        "\n",
        "To keep this example short, you'll be training for just 10 epochs. To visualize the training progress in TensorBoard later, create and store logs an a [TensorBoard callback](https://www.tensorflow.org/tensorboard/get_started#using_tensorboard_with_keras_modelfit)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JI0yAKd-nARd"
      },
      "outputs": [],
      "source": [
        "NUM_EPOCHS = 10\n",
        "\n",
        "history = model.fit(train_ds,\n",
        "                    validation_data=val_ds,\n",
        "                    epochs=NUM_EPOCHS,\n",
        "                    callbacks=tensorboard_callback)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tiDbmiAK_h03"
      },
      "source": [
        "Start the TensorBoard to view how the metrics change with each epoch and to track other scalar values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yVJar0MiT2t"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir logs/fit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36a9d7cab8c8"
      },
      "source": [
        "<!-- <img class=\"tfo-display-only-on-site\" src=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/images/tensorboard_transfer_learning_with_hub.png?raw=1\"/> -->"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kb__ZN8uFn-D"
      },
      "source": [
        "### Check the predictions\n",
        "\n",
        "Obtain the ordered list of class names from the model predictions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JGbEf5l1I4jz"
      },
      "outputs": [],
      "source": [
        "predicted_batch = model.predict(image_batch)\n",
        "predicted_id = tf.math.argmax(predicted_batch, axis=-1)\n",
        "predicted_label_batch = class_names[predicted_id]\n",
        "print(predicted_label_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CkGbZxl9GZs-"
      },
      "source": [
        "Plot the model predictions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hW3Ic_ZlwtrZ"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,9))\n",
        "plt.subplots_adjust(hspace=0.5)\n",
        "\n",
        "for n in range(30):\n",
        "  plt.subplot(6,5,n+1)\n",
        "  plt.imshow(image_batch[n])\n",
        "  plt.title(predicted_label_batch[n].title())\n",
        "  plt.axis('off')\n",
        "_ = plt.suptitle(\"Model predictions\")"
      ]
    },
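    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The plot shows predicted labels only. To put a number on how well a batch was classified, you can compare the predicted class IDs with the true ones. The following is a minimal NumPy sketch under stated assumptions: every `example_*` name is hypothetical, and the class names, labels, and logits are made-up values rather than model output. With the notebook's own variables, the same comparison would presumably use `predicted_id` and `labels_batch`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Made-up stand-ins for one batch: class names, true label IDs, and logits.\n",
        "example_class_names = np.array(['daisy', 'dandelion', 'roses'])\n",
        "example_true_ids = np.array([0, 2, 1, 1])\n",
        "example_logits = np.array([\n",
        "    [2.0, 0.1, -1.0],\n",
        "    [0.3, 0.2, 1.5],\n",
        "    [1.2, 0.4, 0.1],\n",
        "    [-0.5, 2.2, 0.0],\n",
        "])\n",
        "\n",
        "# Pick the highest-scoring class per row, then compare with the true IDs.\n",
        "example_predicted_ids = np.argmax(example_logits, axis=-1)\n",
        "example_accuracy = np.mean(example_predicted_ids == example_true_ids)\n",
        "\n",
        "print(example_class_names[example_predicted_ids])\n",
        "print(example_accuracy)"
      ]
    },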
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uRcJnAABr22x"
      },
      "source": [
        "## Export and reload your model\n",
        "\n",
        "Now that you've trained the model, export it as a SavedModel for reusing it later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PLcqg-RmsLno"
      },
      "outputs": [],
      "source": [
        "t = time.time()\n",
        "\n",
        "export_path = \"/tmp/saved_models/{}\".format(int(t))\n",
        "model.save(export_path)\n",
        "\n",
        "export_path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AhQ9liIUsPsi"
      },
      "source": [
        "Confirm that you can reload the SavedModel and that the model is able to output the same results:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7nI5fvkAQvbS"
      },
      "outputs": [],
      "source": [
        "reloaded = tf.keras.models.load_model(export_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dnZO14taYPH6"
      },
      "outputs": [],
      "source": [
        "result_batch = model.predict(image_batch)\n",
        "reloaded_result_batch = reloaded.predict(image_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wtjsIPjQnPyM"
      },
      "outputs": [],
      "source": [
        "abs(reloaded_result_batch - result_batch).max()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jor83-LqI8xW"
      },
      "outputs": [],
      "source": [
        "reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)\n",
        "reloaded_predicted_label_batch = class_names[reloaded_predicted_id]\n",
        "print(reloaded_predicted_label_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RkQIBksVkxPO"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,9))\n",
        "plt.subplots_adjust(hspace=0.5)\n",
        "for n in range(30):\n",
        "  plt.subplot(6,5,n+1)\n",
        "  plt.imshow(image_batch[n])\n",
        "  plt.title(reloaded_predicted_label_batch[n].title())\n",
        "  plt.axis('off')\n",
        "_ = plt.suptitle(\"Model predictions\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mSBRrW-MqBbk"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "You can use the SavedModel to load for inference or convert it to a [TensorFlow Lite](https://www.tensorflow.org/lite/convert/)  model (for on-device machine learning) or a [TensorFlow.js](https://www.tensorflow.org/js/tutorials#convert_pretrained_models_to_tensorflowjs) model (for machine learning in JavaScript).\n",
        "\n",
        "Discover [more tutorials](https://www.tensorflow.org/hub/tutorials) to learn how to use pre-trained models from TensorFlow Hub on image, text, audio, and video tasks."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "W_tvPdyfA-BL"
      ],
      "name": "transfer_learning_with_hub.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
